# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn
from loguru import logger


def test_mochi_diffusers_pipeline():
    """
    Load the reference Diffusers Mochi pipeline and run it end to end on one prompt.

    The test logs the pipeline's constituent models, generates a short
    low-resolution clip, checks that at least one frame was returned, and
    writes the clip to ``mochi_test_output.mp4`` in the current working
    directory for manual inspection. It is useful for investigating the
    internal workings of the pipeline.

    The test is skipped when ``diffusers`` (or its ``MochiPipeline``) cannot be
    imported.
    """
    try:
        from diffusers import MochiPipeline
        from diffusers.utils import export_to_video
    except ImportError:
        pytest.skip("diffusers library not available or MochiPipeline not found")

    # Prefer the GPU with bfloat16 weights; fall back to float32 on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

    logger.info(f"Loading Mochi pipeline on device: {device} with dtype: {torch_dtype}")

    # Load the Mochi pipeline
    pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch_dtype)

    if device == "cuda":
        # Model CPU offload manages device placement itself (each sub-model is
        # moved to the GPU only while it runs), so the pipeline is deliberately
        # NOT moved to CUDA up front: doing so would first materialise every
        # sub-model on the GPU at once, which is what offloading is meant to avoid.
        pipe.enable_model_cpu_offload()
        pipe.enable_vae_tiling()
    else:
        pipe.to(device)

    # Print out the constituent models of the pipeline
    logger.info("=== Mochi Pipeline Constituent Models ===")
    for attr_name in dir(pipe):
        if attr_name.startswith("_"):
            continue
        try:
            attr_value = getattr(pipe, attr_name)
            type_repr = str(type(attr_value))
            # Heuristic: treat the attribute as a model component if it exposes
            # ``parameters`` or its type lives in torch.nn / diffusers.
            looks_like_model = (
                hasattr(attr_value, "parameters") or "torch.nn" in type_repr or "diffusers" in type_repr
            )
            if not looks_like_model:
                continue

            logger.info(f"  {attr_name}: {type(attr_value)} - {attr_value.__class__.__module__}")

            # Print additional info for key components
            if hasattr(attr_value, "config"):
                try:
                    config_keys = list(attr_value.config.keys()) if hasattr(attr_value.config, "keys") else "N/A"
                    logger.info(f"    Config keys: {config_keys}")
                except Exception as e:
                    logger.info(f"    Config keys: Error accessing config - {e}")
        except AttributeError as e:
            # Skip attributes that raise AttributeError when accessed
            logger.debug(f"  Skipped {attr_name}: {e}")
        except Exception as e:
            # Introspection is best-effort; never let it fail the test.
            logger.debug(f"  Error accessing {attr_name}: {e}")

    # Also check the pipeline's components attribute if it exists
    if hasattr(pipe, "components"):
        logger.info("\n=== Pipeline Components Dictionary ===")
        for component_name, component in pipe.components.items():
            logger.info(f"  {component_name}: {type(component)}")
            if hasattr(component, "config"):
                try:
                    logger.info(f"    Config: {type(component.config)}")
                except Exception as e:
                    logger.info(f"    Config: Error accessing config - {e}")

    # Define test prompt
    prompt = "A close-up of a beautiful butterfly landing on a flower, wings gently moving in the breeze."

    logger.info(f"Generating video with prompt: '{prompt}'")

    # Generate frames with reduced parameters for faster testing
    frames = pipe(
        prompt,
        num_inference_steps=10,  # Reduced for faster testing
        guidance_scale=3.5,
        num_frames=16,  # Reduced for faster testing
        height=320,  # Reduced resolution for faster testing
        width=320,  # Reduced resolution for faster testing
    ).frames[0]

    # Validate output
    assert frames is not None, "No frames were generated by the pipeline"
    assert len(frames) > 0, "Empty frames list generated by the pipeline"
    # NOTE(review): an exact frame-count check (``len(frames) == 16``) was
    # disabled in the original; presumably the pipeline may return a different
    # count than requested -- confirm before re-enabling it.

    # Check frame properties
    first_frame = frames[0]
    logger.info(f"Generated {len(frames)} frames, first frame size: {first_frame.size}")

    # Optional: Export to video file for manual inspection. The export is
    # best-effort, and success is only reported when it actually succeeded.
    try:
        export_to_video(frames, "mochi_test_output.mp4", fps=8)
    except AttributeError as e:
        logger.info(f"AttributeError: {e}")
    else:
        logger.info("Video exported to mochi_test_output.mp4")

    logger.info("Mochi pipeline test completed successfully!")


@pytest.mark.parametrize(
    "mesh_device, sp_axis, tp_axis, vae_mesh_shape, vae_sp_axis, vae_tp_axis, num_links",
    [
        # VAE mesh shape = (1, 8) is more memory efficient.
        [(1, 8), 1, 0, (1, 8), 0, 1, 1],
        [(2, 4), 0, 1, (1, 8), 0, 1, 1],
        [(4, 8), 1, 0, (4, 8), 0, 1, 4],  # note sp <-> tp switch for VAE for memory efficiency.
    ],
    ids=[
        "dit_1x8sp1tp0_vae_1x8sp0tp1",
        "dit_2x4sp0tp1_vae_1x8sp0tp1",
        "dit_4x8sp1tp0_vae_4x8sp0tp1",
    ],
    indirect=["mesh_device"],
)
@pytest.mark.parametrize("device_params", [{"fabric_config": ttnn.FabricConfig.FABRIC_1D}], indirect=True)
def test_tt_mochi_pipeline(
    mesh_device: ttnn.MeshDevice,
    sp_axis: int,
    tp_axis: int,
    vae_mesh_shape: tuple,
    vae_sp_axis: int,
    vae_tp_axis: int,
    num_links: int,
):
    """
    Create the modified TT MochiPipeline and run it end to end on one prompt.

    This uses the TT transformer instead of the diffusers one. The test builds
    the DiT and VAE parallel configurations from the parametrized mesh layout,
    generates a clip with a fixed seed, checks that at least one frame was
    returned, and tries to write the clip to ``tt_mochi_test_output.mp4`` in
    the current working directory.

    Args:
        mesh_device: Mesh device fixture, built from the parametrized mesh shape.
        sp_axis: Mesh axis used for DiT sequence parallelism.
        tp_axis: Mesh axis used for DiT tensor parallelism.
        vae_mesh_shape: Mesh shape handed to the pipeline for the VAE.
        vae_sp_axis: VAE mesh axis that is split between height and width parallelism.
        vae_tp_axis: VAE mesh axis used for time parallelism.
        num_links: Forwarded unchanged to the pipeline as ``num_links``.

    The test is skipped when the TT pipeline modules cannot be imported.
    """
    try:
        from ....pipelines.mochi.pipeline_mochi import MochiPipeline as TTMochiPipeline
        from ....parallel.config import DiTParallelConfig, MochiVAEParallelConfig, ParallelFactor
    except ImportError as e:
        pytest.skip(f"Required TT modules not available: {e}")

    dit_mesh_shape = tuple(mesh_device.shape)
    sp_factor = dit_mesh_shape[sp_axis]
    tp_factor = dit_mesh_shape[tp_axis]

    logger.info(
        f"Creating TT Mochi pipeline with DiT mesh device shape {mesh_device.shape}, VAE mesh device shape {vae_mesh_shape}"
    )
    logger.info(f"DiT SP axis: {sp_axis}, TP axis: {tp_axis}")
    # Report the VAE's own TP axis (the original logged the DiT ``tp_axis`` here).
    logger.info(f"VAE SP axis: {vae_sp_axis}, TP axis: {vae_tp_axis}")

    # Create parallel config
    parallel_config = DiTParallelConfig(
        cfg_parallel=ParallelFactor(factor=1, mesh_axis=0),
        tensor_parallel=ParallelFactor(factor=tp_factor, mesh_axis=tp_axis),
        sequence_parallel=ParallelFactor(factor=sp_factor, mesh_axis=sp_axis),
    )

    # The VAE "sp" axis is shared by height and width parallelism: width takes
    # a factor of 2 whenever that axis has more than one device, and height
    # takes whatever remains.
    vae_sp_size = vae_mesh_shape[vae_sp_axis]
    w_parallel_factor = 1 if vae_sp_size == 1 else 2

    vae_parallel_config = MochiVAEParallelConfig(
        time_parallel=ParallelFactor(factor=vae_mesh_shape[vae_tp_axis], mesh_axis=vae_tp_axis),
        w_parallel=ParallelFactor(factor=w_parallel_factor, mesh_axis=vae_sp_axis),
        h_parallel=ParallelFactor(factor=vae_sp_size // w_parallel_factor, mesh_axis=vae_sp_axis),
    )
    # Sanity-check that the h/w split exactly covers the axis (guards against
    # the integer division above dropping devices on an odd-sized axis).
    assert vae_parallel_config.h_parallel.factor * vae_parallel_config.w_parallel.factor == vae_sp_size
    assert vae_parallel_config.h_parallel.mesh_axis == vae_parallel_config.w_parallel.mesh_axis

    # Create the TT Mochi pipeline
    tt_pipe = TTMochiPipeline(
        mesh_device=mesh_device,
        vae_mesh_shape=vae_mesh_shape,
        parallel_config=parallel_config,
        vae_parallel_config=vae_parallel_config,
        num_links=num_links,
        use_cache=True,
        use_reference_vae=False,
        model_name="genmo/mochi-1-preview",
    )

    # Use a generator for deterministic results.
    generator = torch.Generator("cpu").manual_seed(0)

    # Define test prompt (same as the diffusers test)
    prompt = "A close-up of a beautiful butterfly landing on a flower, wings gently moving in the breeze."

    logger.info(f"Generating video with TT pipeline using prompt: '{prompt}'")

    # Generation settings. Unlike the diffusers reference test in this file
    # (10 steps, 16 frames, 320x320), these values are NOT reduced for speed.
    frames = tt_pipe(
        prompt,
        num_inference_steps=50,
        guidance_scale=3.5,
        num_frames=168,
        height=480,
        width=848,
        generator=generator,
    ).frames[0]

    # Validate output
    assert frames is not None, "No frames were generated by the TT pipeline"
    assert len(frames) > 0, "Empty frames list generated by the TT pipeline"

    # Check frame properties
    first_frame = frames[0]
    logger.info(f"TT Pipeline generated {len(frames)} frames, first frame size: {first_frame.size}")

    # Optional: Export to video file for comparison. The export is best-effort
    # and must not fail the test.
    try:
        from diffusers.utils import export_to_video

        export_to_video(frames, "tt_mochi_test_output.mp4", fps=30)
        logger.info("TT Pipeline video exported to tt_mochi_test_output.mp4")
    except ImportError:
        logger.info("Could not export video - diffusers.utils.export_to_video not available")
    except AttributeError as e:
        logger.info(f"AttributeError: {e}")

    logger.info("TT Mochi pipeline test completed successfully!")
